第14章 RAG知识库系统实现
学习目标
- 掌握完整RAG知识库系统的设计与实现流程
- 学习构建高质量文档处理与索引流水线
- 理解RAG系统中各组件的协同工作机制
- 了解如何评估和优化RAG知识库系统性能
RAG知识库系统概述
RAG知识库系统是一种将检索增强生成(Retrieval-Augmented Generation, RAG)与知识管理相结合的智能系统,可以帮助用户高效地获取、理解和应用大规模专业知识。
本节内容可结合对应的课程视频一起学习。
核心组件与流程
一个完整的RAG知识库系统通常包含以下核心组件:
- 文档处理与索引系统:负责文档的采集、预处理、分割和索引
- 检索系统:支持多种检索策略,找到最相关的信息
- 生成系统:使用检索结果作为上下文,生成高质量回答
- 用户交互界面:提供直观的交互方式,支持问答和反馈
- 评估与监控系统:持续评估系统性能,收集反馈以优化系统
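为了直观理解各组件如何协同,下面给出一个极简的流程骨架(仅为示意,类名与函数名均为本文假设,并非标准实现):
```python
# 一个RAG知识库系统的极简流程骨架:仅演示组件间的数据流,所有名称均为示意
from dataclasses import dataclass, field

@dataclass
class Chunk:
    text: str
    metadata: dict = field(default_factory=dict)

def process_documents(raw_texts):
    """文档处理与索引系统:预处理并切分文档(这里仅按空行切分示意)"""
    return [Chunk(p.strip()) for t in raw_texts for p in t.split("\n\n") if p.strip()]

def retrieve(query, chunks, k=3):
    """检索系统:用朴素的字符重叠打分代替真正的向量/关键词检索"""
    scored = sorted(chunks, key=lambda c: len(set(query) & set(c.text)), reverse=True)
    return scored[:k]

def generate(query, context):
    """生成系统:实际应调用LLM,这里仅拼接上下文示意"""
    ctx = "\n".join(c.text for c in context)
    return f"[基于{len(context)}个片段回答] {query}\n参考上下文:\n{ctx}"

def evaluate(answer, feedback):
    """评估与监控系统:记录用户反馈,供后续优化使用"""
    print(f"反馈: {feedback} | 回答长度: {len(answer)}")

# 用户交互界面通常在这些函数之上封装问答与反馈入口
kb_chunks = process_documents(["RAG系统结合检索与生成。\n\n向量检索依赖嵌入模型。"])
answer = generate("什么是RAG?", retrieve("什么是RAG?", kb_chunks))
evaluate(answer, "有帮助")
```
实际系统中,检索与生成会分别替换为本章后文实现的混合检索器和LLM调用。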
文档处理与索引系统实现
1. 文档采集与预处理
首先,我们需要构建一个灵活的文档加载和预处理流水线:
```python
from langchain.document_loaders import (
PyPDFLoader,
TextLoader,
CSVLoader,
UnstructuredMarkdownLoader,
WebBaseLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import re
class DocumentProcessor:
def __init__(self):
# 支持的文件类型和对应的加载器
self.loaders = {
".pdf": PyPDFLoader,
".txt": TextLoader,
".csv": CSVLoader,
".md": UnstructuredMarkdownLoader
}
# 初始化文本分割器
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
def load_directory(self, directory_path):
"""加载整个目录的文档"""
documents = []
for root, _, files in os.walk(directory_path):
for file in files:
file_path = os.path.join(root, file)
try:
file_documents = self.load_file(file_path)
documents.extend(file_documents)
except Exception as e:
print(f"Error loading {file_path}: {e}")
return documents
def load_file(self, file_path):
"""加载单个文件"""
ext = os.path.splitext(file_path)[1].lower()
if ext in self.loaders:
loader = self.loaders[ext](file_path)
return loader.load()
else:
raise ValueError(f"Unsupported file type: {ext}")
def load_web(self, urls):
"""加载网页内容"""
loader = WebBaseLoader(urls)
return loader.load()
def preprocess_text(self, text):
"""文本预处理"""
# 移除多余空白
text = re.sub(r'\s+', ' ', text).strip()
        # 移除特殊字符(保留常见中英文标点,避免误删中文文本中的标点)
        text = re.sub(r'[^\w\s.,;:!?()\[\]{}"\'，。、；：！？（）【】《》“”‘’-]', '', text)
return text
def split_documents(self, documents):
"""文档分割"""
return self.text_splitter.split_documents(documents)
def process_documents(self, source, source_type="directory"):
"""完整处理流程"""
# 加载文档
if source_type == "directory":
documents = self.load_directory(source)
elif source_type == "file":
documents = self.load_file(source)
elif source_type == "web":
documents = self.load_web(source)
else:
raise ValueError(f"Unsupported source type: {source_type}")
# 预处理和分割
for doc in documents:
doc.page_content = self.preprocess_text(doc.page_content)
return self.split_documents(documents)
# 使用文档处理器
processor = DocumentProcessor()
documents = processor.process_documents("./knowledge_base", source_type="directory")
print(f"Processed {len(documents)} document chunks")
2. 高级文档分割策略
对于不同类型的文档,我们可以使用更智能的分割策略:
```python
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    MarkdownTextSplitter,
    PythonCodeTextSplitter,
    Language
)
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import DeepSeek
class SmartDocumentSplitter:
def __init__(self, llm):
self.llm = llm
# 初始化不同类型的分割器
self.splitters = {
"default": RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
),
"markdown": MarkdownTextSplitter(
chunk_size=1000,
chunk_overlap=200
),
"python": PythonCodeTextSplitter(
chunk_size=1000,
chunk_overlap=200
),
"html": HTMLTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
}
# 创建语义分割链
semantic_split_template = """
分析以下文本并找出最佳的语义分割点,将其分成多个连贯且相对独立的段落。
每个段落应保持内部的语义完整性,最好是在自然的主题转换点进行分割。
原始文本:
{text}
请提供分割点的位置(字符索引),每个索引占一行:
"""
self.semantic_splitter = LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=["text"],
template=semantic_split_template
)
)
def split_by_type(self, document, doc_type="default"):
"""根据文档类型选择分割器"""
splitter = self.splitters.get(doc_type, self.splitters["default"])
return splitter.split_documents([document])
def split_by_semantic(self, document, max_chunk_size=1000):
"""基于语义内容进行智能分割"""
text = document.page_content
# 对于短文本,直接返回
if len(text) <= max_chunk_size:
return [document]
        # 使用LLM在文本开头部分(最多取4000字符)寻找语义分割点,超出部分会归入最后一个块
        try:
            split_response = self.semantic_splitter.run(text=text[:4000])
split_indices = [int(idx.strip()) for idx in split_response.split("\n") if idx.strip().isdigit()]
# 确保分割点有效
split_indices = [idx for idx in split_indices if 0 < idx < len(text)]
split_indices = sorted(split_indices)
# 如果没有有效分割点,回退到默认分割
if not split_indices:
return self.split_by_type(document)
# 根据分割点创建文档块
chunks = []
start_idx = 0
for idx in split_indices:
chunk_text = text[start_idx:idx].strip()
if chunk_text:
chunk_doc = document.copy()
chunk_doc.page_content = chunk_text
chunks.append(chunk_doc)
start_idx = idx
# 添加最后一个块
if start_idx < len(text):
chunk_text = text[start_idx:].strip()
if chunk_text:
chunk_doc = document.copy()
chunk_doc.page_content = chunk_text
chunks.append(chunk_doc)
return chunks
except Exception as e:
print(f"Error in semantic splitting: {e}")
return self.split_by_type(document)
def split_mixed_content(self, document):
"""处理混合内容文档"""
# 检测文档类型
content = document.page_content.lower()
if "```python" in content or "```py" in content:
# 包含Python代码
return self.split_by_type(document, "python")
elif content.count("#") > 5 or "---" in content:
# 可能是Markdown
return self.split_by_type(document, "markdown")
elif "<html" in content or "<div" in content or "<p>" in content:
# 可能是HTML
return self.split_by_type(document, "html")
else:
# 尝试语义分割
return self.split_by_semantic(document)
# 使用智能分割器
llm = DeepSeek(api_key="your-api-key")
smart_splitter = SmartDocumentSplitter(llm)
# 处理文档
document = documents[0] # 假设已有documents列表
chunks = smart_splitter.split_mixed_content(document)
print(f"Split into {len(chunks)} chunks")
3. 元数据提取与增强
为文档添加丰富的元数据信息,可以显著提升检索效果:
```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import hashlib
import datetime
class MetadataEnhancer:
def __init__(self, llm):
self.llm = llm
# 创建主题提取链
topic_extract_template = """
分析以下文本,提取3-5个最能代表其核心主题的关键词或短语。
文本:
{text}
关键主题(以逗号分隔):
"""
self.topic_extractor = LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=["text"],
template=topic_extract_template
)
)
# 创建摘要生成链
summary_template = """
为以下文本生成一个简洁的摘要(50字以内)。
文本:
{text}
摘要:
"""
self.summarizer = LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=["text"],
template=summary_template
)
)
def extract_metadata(self, document):
"""提取和增强文档元数据"""
# 复制原有元数据
metadata = document.metadata.copy() if hasattr(document, "metadata") else {}
text = document.page_content
# 基础元数据
metadata["doc_id"] = hashlib.md5(text.encode()).hexdigest()
metadata["char_count"] = len(text)
metadata["word_count"] = len(text.split())
metadata["processed_date"] = datetime.datetime.now().isoformat()
# 提取主题
try:
topics_text = self.topic_extractor.run(text=text[:min(len(text), 3000)])
            topics = [topic.strip() for topic in topics_text.split(",") if topic.strip()]
            # 以逗号分隔的字符串形式保存,便于写入只接受标量元数据的向量库(如Chroma)
            metadata["topics"] = ", ".join(topics)
except Exception as e:
print(f"Error extracting topics: {e}")
# 生成摘要
try:
metadata["summary"] = self.summarizer.run(text=text[:min(len(text), 3000)])
except Exception as e:
print(f"Error generating summary: {e}")
# 检测语言(简单实现)
chinese_char_ratio = len([c for c in text if '\u4e00' <= c <= '\u9fff']) / max(len(text), 1)
metadata["language"] = "zh" if chinese_char_ratio > 0.1 else "en"
# 更新文档元数据
document.metadata = metadata
return document
def enhance_batch(self, documents):
"""批量增强文档元数据"""
enhanced_docs = []
for doc in documents:
enhanced_docs.append(self.extract_metadata(doc))
return enhanced_docs
# 使用元数据增强器
enhancer = MetadataEnhancer(llm)
enhanced_documents = enhancer.enhance_batch(chunks)
print(f"Enhanced {len(enhanced_documents)} documents with metadata")
4. 混合索引构建
构建同时支持向量检索和关键词检索的混合索引:
```python
from langchain.vectorstores import Chroma
from langchain.embeddings import DeepSeekEmbeddings
from langchain.retrievers.bm25 import BM25Retriever
import pickle
import os
class HybridIndexBuilder:
def __init__(self, embeddings, persist_directory="./hybrid_index"):
self.embeddings = embeddings
self.persist_directory = persist_directory
os.makedirs(persist_directory, exist_ok=True)
def build_vector_index(self, documents):
"""构建向量索引"""
vector_db = Chroma.from_documents(
documents=documents,
embedding=self.embeddings,
persist_directory=os.path.join(self.persist_directory, "vector_db")
)
vector_db.persist()
return vector_db
def build_keyword_index(self, documents):
"""构建关键词索引"""
bm25_retriever = BM25Retriever.from_documents(documents)
# 保存BM25检索器
with open(os.path.join(self.persist_directory, "bm25.pkl"), "wb") as f:
pickle.dump(bm25_retriever, f)
return bm25_retriever
def build_hybrid_index(self, documents):
"""构建混合索引"""
print("Building vector index...")
vector_db = self.build_vector_index(documents)
print("Building keyword index...")
bm25_retriever = self.build_keyword_index(documents)
# 保存文档ID映射
doc_ids = {doc.metadata.get("doc_id", i): i for i, doc in enumerate(documents)}
with open(os.path.join(self.persist_directory, "doc_ids.pkl"), "wb") as f:
pickle.dump(doc_ids, f)
print("Hybrid index built successfully")
return {
"vector_db": vector_db,
"bm25_retriever": bm25_retriever,
"doc_ids": doc_ids
}
def load_hybrid_index(self):
"""加载已有的混合索引"""
# 加载向量数据库
vector_db = Chroma(
persist_directory=os.path.join(self.persist_directory, "vector_db"),
embedding_function=self.embeddings
)
# 加载BM25检索器
with open(os.path.join(self.persist_directory, "bm25.pkl"), "rb") as f:
bm25_retriever = pickle.load(f)
# 加载文档ID映射
with open(os.path.join(self.persist_directory, "doc_ids.pkl"), "rb") as f:
doc_ids = pickle.load(f)
return {
"vector_db": vector_db,
"bm25_retriever": bm25_retriever,
"doc_ids": doc_ids
}
# 初始化嵌入模型
embeddings = DeepSeekEmbeddings(api_key="your-api-key")
# 构建混合索引
index_builder = HybridIndexBuilder(embeddings)
indices = index_builder.build_hybrid_index(enhanced_documents)
# 加载已有索引
# loaded_indices = index_builder.load_hybrid_index()
```
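知识库通常会持续新增文档。一个简单的增量更新思路如下(示意性质,update_hybrid_index 为本文假设的辅助函数):向量库可以直接追加新文档,而 BM25 检索器依赖全量语料的词频统计,这里选择基于全量文档重建:
```python
# 增量更新混合索引的示意(update_hybrid_index为本文假设的辅助函数)
def update_hybrid_index(index_builder, indices, new_documents, all_documents):
    # 向量库支持直接追加新文档
    indices["vector_db"].add_documents(new_documents)
    indices["vector_db"].persist()
    # BM25依赖全量语料的词频统计,这里基于全量文档重建关键词索引
    indices["bm25_retriever"] = index_builder.build_keyword_index(all_documents)
    return indices

# 用法示意:new_docs为新处理好的文档块,all_docs为历史加新增的全量块
# indices = update_hybrid_index(index_builder, indices, new_docs, all_docs)
```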
高性能检索系统实现
基于前一节学习的混合检索策略,我们设计一个完整的高性能检索系统:
```python
from langchain.retrievers import EnsembleRetriever
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import numpy as np
class AdvancedRetrievalSystem:
def __init__(self, hybrid_indices, llm, embeddings):
self.vector_db = hybrid_indices["vector_db"]
self.bm25_retriever = hybrid_indices["bm25_retriever"]
self.doc_ids = hybrid_indices["doc_ids"]
self.llm = llm
self.embeddings = embeddings
# 创建检索器
self.setup_retrievers()
# 创建查询优化器
self.setup_query_optimizer()
def setup_retrievers(self):
"""设置各种检索器"""
# 向量检索器
self.vector_retriever = self.vector_db.as_retriever(
search_kwargs={"k": 5}
)
# 集成检索器
self.ensemble_retriever = EnsembleRetriever(
retrievers=[self.bm25_retriever, self.vector_retriever],
weights=[0.3, 0.7]
)
def setup_query_optimizer(self):
"""设置查询优化器"""
query_optimizer_template = """
请分析以下用户查询,并将其改写为更适合检索系统的形式。
添加关键术语,消除歧义,并确保查询的意图清晰。
用户查询: {query}
改写后的查询:
"""
self.query_optimizer = LLMChain(
llm=self.llm,
prompt=PromptTemplate(
input_variables=["query"],
template=query_optimizer_template
)
)
def filter_similar_documents(self, documents, threshold=0.95):
"""过滤相似文档"""
if not documents:
return []
        # 批量计算所有文档的嵌入(LangChain的Embeddings接口提供embed_documents/embed_query)
        embeddings_list = self.embeddings.embed_documents(
            [doc.page_content for doc in documents]
        )
# 计算相似度矩阵
similarity_matrix = np.zeros((len(documents), len(documents)))
for i in range(len(documents)):
for j in range(i, len(documents)):
if i == j:
similarity_matrix[i][j] = 1.0
else:
# 计算余弦相似度
similarity = np.dot(embeddings_list[i], embeddings_list[j]) / (
np.linalg.norm(embeddings_list[i]) * np.linalg.norm(embeddings_list[j])
)
similarity_matrix[i][j] = similarity
similarity_matrix[j][i] = similarity
# 贪婪选择不相似的文档
selected_indices = []
for i in range(len(documents)):
# 检查当前文档是否与已选择的任何文档过于相似
is_similar = False
for selected_idx in selected_indices:
if similarity_matrix[i][selected_idx] > threshold:
is_similar = True
break
if not is_similar:
selected_indices.append(i)
# 返回过滤后的文档
return [documents[i] for i in selected_indices]
def retrieve(self, query, strategy="hybrid", optimize_query=True, filter_similar=True):
"""执行检索"""
# 查询优化
if optimize_query:
try:
optimized_query = self.query_optimizer.run(query=query)
query = optimized_query
except Exception as e:
print(f"Error optimizing query: {e}")
# 选择检索策略
if strategy == "vector":
results = self.vector_retriever.get_relevant_documents(query)
elif strategy == "keyword":
results = self.bm25_retriever.get_relevant_documents(query)
elif strategy == "hybrid":
results = self.ensemble_retriever.get_relevant_documents(query)
else:
raise ValueError(f"Unknown retrieval strategy: {strategy}")
# 过滤相似文档
if filter_similar and results:
results = self.filter_similar_documents(results)
return results
# 创建检索系统
retrieval_system = AdvancedRetrievalSystem(indices, llm, embeddings)
# 执行检索
results = retrieval_system.retrieve(
"深度学习在自然语言处理中的应用",
strategy="hybrid",
optimize_query=True,
filter_similar=True
)
print(f"Retrieved {len(results)} documents")
for i, doc in enumerate(results):
print(f"Document {i+1}:")
print(f"Summary: {doc.metadata.get('summary', 'N/A')}")
print(f"Topics: {doc.metadata.get('topics', 'N/A')}")
print(f"Content: {doc.page_content[:100]}...")
print("-" * 50)
下一节我们将继续讨论RAG系统的生成部分和评估方法。
思考题
1. 在构建RAG知识库时,如何确定最优的文档分割粒度?这个决策会如何影响检索效果?
2. 对于包含大量表格、图表和代码的技术文档,应该采取什么特殊的处理策略?
3. 元数据提取对RAG系统的检索性能有什么影响?哪些类型的元数据对不同领域的文档特别重要?
4. 在实际应用中,如何平衡索引构建的时间复杂度和检索性能?有哪些优化技巧?